import numpy as np
from optimizer import inexact_newton, localSGD, fedAC
from function import logistic
from run_help import repeat_run
from data import loada9a
import os


def repeat(F, mu, M, store, T, R_values, rep):
    """Run the enabled optimizers once per round count in ``R_values``.

    For each ``R`` the number of local steps per round is ``K = int(T / R)``,
    so the total step budget ``K * R`` equals ``T`` whenever ``R`` divides
    ``T``. Each method's learning rate is read from ``store[key][i][0]``,
    where ``i`` is the position of ``R`` in ``R_values`` (presumably the lr
    chosen by an earlier tuning sweep -- confirm against the code that writes
    ``store``).

    Only FedAC-1 and FedAC-2 are currently run. The other methods are kept
    below as commented-out configurations, and their result lists stay empty.

    Args:
        F: objective forwarded to the optimizer as ``F``.
        mu: value forwarded to FedAC as ``mu``.
        M: value forwarded to the optimizer as ``M`` (presumably the number
            of parallel workers -- confirm against ``optimizer``).
        store: mapping from method key to a sequence aligned with
            ``R_values``; only ``store[key][i][0]`` is read.
        T: total step budget, split as ``K * R``.
        R_values: round counts to sweep over.
        rep: repetition count forwarded to ``repeat_run``.

    Returns:
        dict mapping each method key to a list with one ``repeat_run`` result
        per entry of ``R_values`` (an empty list for disabled methods).
    """
    Losses = {key: [] for key in (
        "Newton_w_M", "LSGD_w_M", "MBSGD_w_M", "FEDAC1", "FEDAC2",
        "Newton", "LSGD", "MBSGD")}

    # (result/store key, label passed to repeat_run, FedAC version)
    fedac_variants = (("FEDAC1", "FedAC-1", 1), ("FEDAC2", "FedAC-2", 2))

    for i, R in enumerate(R_values):
        K = int(T / R)

        # --- Disabled: Inexact Newton (uncomment to run) -------------------
        # lr = store["Newton_w_M"][i][0]
        # args = {"F": F, "T": R, "alpha": 1.25, "M": M, "K": K, "R": 1,
        #         "lr": lr, "momentum": 0.9, "mu": 0, "damp": True}
        # losses = repeat_run(
        #     inexact_newton, "Inexact Newton w/ Momentum", rep, **args)
        # Losses["Newton_w_M"].append(losses)
        #
        # lr = store["Newton"][i][0]
        # args = {"F": F, "T": R, "alpha": 1.25, "M": M, "K": K, "R": 1,
        #         "lr": lr, "momentum": 0, "mu": 0, "damp": True}
        # losses = repeat_run(inexact_newton, "Inexact Newton", rep, **args)
        # Losses["Newton"].append(losses)

        # --- Disabled: Local SGD (uncomment to run) ------------------------
        # lr = store["LSGD_w_M"][i][0]
        # args = {"F": F, "M": M, "K": K, "R": R,
        #         "mu": 0, "u": None, "lr": lr, "momentum": 0.9, "forHVP": False}
        # losses = repeat_run(localSGD, "Local SGD w/ Momentum", rep, **args)
        # Losses["LSGD_w_M"].append(losses)
        #
        # lr = store["LSGD"][i][0]
        # args = {"F": F, "M": M, "K": K, "R": R,
        #         "mu": 0, "u": None, "lr": lr, "momentum": 0, "forHVP": False}
        # losses = repeat_run(localSGD, "Local SGD", rep, **args)
        # Losses["LSGD"].append(losses)

        # --- Disabled: Mini-batch SGD (uncomment to run) -------------------
        # Uses localSGD with batch M*K and a single local step per round.
        # lr = store["MBSGD_w_M"][i][0]
        # args = {"F": F, "M": M*K, "K": 1, "R": R,
        #         "mu": 0, "u": None, "lr": lr, "momentum": 0.9, "forHVP": False}
        # losses = repeat_run(
        #     localSGD, "Mini-batch SGD w/ Momentum", rep, **args)
        # Losses["MBSGD_w_M"].append(losses)
        #
        # lr = store["MBSGD"][i][0]
        # args = {"F": F, "M": M*K, "K": 1, "R": R,
        #         "mu": 0, "u": None, "lr": lr, "momentum": 0, "forHVP": False}
        # losses = repeat_run(localSGD, "Mini-batch SGD", rep, **args)
        # Losses["MBSGD"].append(losses)

        # --- Enabled: FedAC-1 then FedAC-2, with the stored learning rates --
        for key, label, ver in fedac_variants:
            lr = store[key][i][0]
            args = {"F": F, "M": M, "K": K, "R": R,
                    "mu": mu, "lr": lr, "ver": ver, "u": None}
            Losses[key].append(repeat_run(fedAC, label, rep, **args))

    return Losses


F = logistic
mu_values = [1e-2, 1e-4, 1e-5]
M_values = [20, 50, 100, 200]

T = 100  # total step budget per run: K * R = T
R_values = [1, 5, 10, 25, 50, 100]
REP = 20  # repetitions per (method, R) configuration, forwarded to repeat()


def _merge_with_previous(new_losses, path):
    """Fold results saved at ``path`` by an earlier run into ``new_losses``.

    For every method key found in the saved file:
      * if this run produced nothing for it (missing key or empty list), the
        saved results are carried over unchanged, so re-saving does not erase
        them;
      * if the saved list is empty, the new results are kept as-is;
      * otherwise the two lists (one entry per R value) are combined
        pairwise with ``+=`` -- assumes each entry is a list of
        per-repetition losses, so this concatenates repetitions (TODO
        confirm against ``repeat_run``'s return type).
    Keys present only in ``new_losses`` are left untouched.

    Mutates and returns ``new_losses``.

    Raises:
        ValueError: if both runs hold results for a method but with a
            different number of R values, so entries cannot be paired.
    """
    if not os.path.exists(path):
        return new_losses
    # The file is written by this script itself; allow_pickle is needed to
    # restore the saved dict.
    old_losses = np.load(path, allow_pickle=True).item()
    for key, old_runs in old_losses.items():
        new_runs = new_losses.get(key)
        if new_runs is None or len(new_runs) == 0:
            new_losses[key] = old_runs
            continue
        if len(old_runs) == 0:
            continue
        if len(old_runs) != len(new_runs):
            raise ValueError(
                f"cannot merge '{key}' from {path}: saved results cover "
                f"{len(old_runs)} R values, this run covers {len(new_runs)}")
        for i, previous in enumerate(old_runs):
            new_runs[i] += previous
    return new_losses


for mu in mu_values:
    for M in M_values:
        print(f"[*] Running for mu = {mu} and M = {M}")
        store = np.load(
            f"results/store_{mu}_{M}.npy", allow_pickle=True).item()
        Losses = repeat(F, mu, M, store, T, R_values, REP)
        losses_path = f"results/Losses_{mu}_{M}.npy"
        Losses = _merge_with_previous(Losses, losses_path)
        np.save(losses_path, Losses)
